clear all
%close all
%addpath('functions')
%addpath('subroutines/sub_Table')
addpath('sub_Table')
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% Setting: Parameters
% Generate artificial data from a non-homothetic preference.
% Preference type flag: 1 = AIDS, 2 = Nh-CES.
paras.preference_type = 2;
paras.flag_sigma = 0; % 0 for constant sigma, 1 for variable sigma
N = 1000;              % number of income levels (people) per period
I = 3;                 % number of goods; adjust together with eps_vec, omega and the price vectors
T = 40;                % number of periods
pols = [1,2,4,6,8,12]; % JL polynomial orders (each is passed as paras.Kn)
name = 'A';
% The indirect utility of Nh-CES is calculated numerically from wages and prices.
% To make the solution easier to find, indirect utility is solved in Ns
% equal-sized batches of people for each t, so N must be a multiple of Ns.
Ns = 20;
assert(mod(N, Ns) == 0, 'N must be a multiple of Ns.');
% Nh-CES parameters
if paras.flag_sigma == 0
    paras.sig = 5;
else
    paras.sig = 10;
end
paras.eta = -2; % from Auer et al. (2021)
paras.eps_vec = [0.3; 1; 2];%./(1-paras.sig)*1;

%paras.eps_vec = [0.2;1;1.65];% Non-homothetic parametrization from Comin et al. (2021)
%paras.eps_vec = [1;1;1];% Homothetic
paras.T = T;
paras.N = N;
paras.I = I;
paras.Ns = Ns;

%% Setting: Price and Wage
% Price of good i grows geometrically from vecp0(i) in period 1 to
% vecp0(i)*vecpratio(i) in period T.
vecpratio = [2 3 4];
vecp0     = [1 1 1];
% max(T-1,1) avoids 0/0 -> NaN prices when T == 1 (identical result for T > 1).
vecpgrowth = log(vecpratio)/max(T-1, 1);
% Incomes: N evenly spaced levels in [minI, maxI], the same in every period.
minI = 1;
maxI = 1.1;

err = 1e-8; % not used in this section; kept in case run("sub_errJL.m") reads it — TODO confirm
matI = zeros(T, N);
matp = zeros(T, numel(vecp0));
for t = 1:T
    matI(t,:) = linspace(minI, maxI, N);
    matp(t,:) = vecp0.*exp(vecpgrowth*(t-1));
end
I_vec = reshape(matI', [1,N,T]); % 1-by-N-by-T incomes
pvec  = reshape(matp', [I,1,T]); % I-by-1-by-T prices

paras.omega = [1;1;1];
paras.uvec_init = I_vec/10; % Initial guess for the numerical indirect-utility solve (cf. Comin et al.).
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% Generating Artificial data %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Solve for indirect utility, expenditure-based utility and budget shares.
[v_vec, u_vec, B_vec, flag] = NhCESmax(I_vec, pvec, paras);
% NhCESmax reports whether any solve failed; surface it instead of ignoring it.
if flag
    warning('main:NhCESmaxNoConvergence', ...
        'NhCESmax: the indirect-utility solve did not converge for at least one batch.');
end


%% JLloop
% For each JL polynomial order, run CalLJ_poly and collect the error
% statistics computed by sub_errJL.m (run in this workspace; it must
% define ErrorJL).

format shortG

%figure('Position',[100 100 800 400],'NumberTitle','off')
warnState = warning; % save the global warning state; the loop silences all warnings
for i = 1:numel(pols) % numel also handles a column vector of orders
    pol = pols(i);
    paras.Kn = pol;
    warning("off")
    lastwarn(''); % presumably reset so sub_errJL.m can inspect lastwarn — TODO confirm
    [LJMM_F,LJMM_SC,JLerrorType] = CalLJ_poly(I_vec, B_vec, pvec,paras);

    run("sub_errJL.m")
    ErrorJLVec(i,:) = ErrorJL;

    JLErrorTypeVec(i,:) = JLerrorType;

end
warning(warnState) % restore warnings so later code is not silently muted


%%  Produce Tables and fig
% Build the LaTeX summary table of JL errors; no trailing semicolon, so
% the result is echoed to the command window.
caption_text = 'JL Algorithm';
latex_table = create_latex_table_JL(ErrorJLVec, pols, ...
                                    JLErrorTypeVec, caption_text)
% To write the table to disk, uncomment:
% save_latex_table_to_file(latex_table, ['../fig/' name '_JL.tex']);

function [v_vec, u_vec, B_vec,flagok]= NhCESmax(I_vec,pvec,paras) 
%NHCESMAX Nonhomothetic (Nh-CES) preference: given income and prices, solve for utility.
%   [v_vec, u_vec, B_vec, flagok] = NhCESmax(I_vec, pvec, paras) finds, for
%   every period t and person n, the level v such that
%       E(v,p) = ( sum_i omega_i * (v^eps_i * p_i)^(1-sig) )^(1/(1-sig))
%   equals the income I_vec(1,n,t) at prices p = pvec(:,:,t).
%
%   Inputs
%     I_vec  incomes, 1-by-N-by-T.
%     pvec   prices, I-by-1-by-T (one column per period).
%     paras  struct with fields T, N, Ns, sig, eps_vec, omega, flag_sigma
%            and uvec_init (initial guess, 1-by-N-by-T or 1-by-N; a 1-by-N
%            guess is reused for every period).
%
%   Outputs
%     v_vec  solved levels, 1-by-N-by-T.
%     u_vec  E(v_vec, p) evaluated at base-period prices pvec(:,:,1).
%     B_vec  expenditure shares at current prices (sum to 1 over goods).
%     flagok true if any batch ended with a non-positive solver exit flag
%            (failure or iteration/evaluation limit reached).
%
%   People are solved in Ns equal batches per period, so paras.N must be a
%   multiple of paras.Ns.
T  = paras.T;
N  = paras.N;
Ns = paras.Ns;
if mod(N, Ns) ~= 0
    error('NhCESmax:badNs', 'paras.N (%d) must be a multiple of paras.Ns (%d).', N, Ns);
end
batchSize = N / Ns;

v_vec = zeros(1,N,T);
flag  = zeros(1,N,T);
uvec_init = paras.uvec_init;
% A 3-D initial guess must be sliced per period: two-subscript indexing
% uvec_init(:,Index) on a 1-by-N-by-T array would always read period 1.
perPeriodInit = size(uvec_init, 3) > 1;

% Solver options do not depend on the loop indices, so build them once.
fsolveOpts  = optimoptions('fsolve', 'TolFun', 1e-8, 'Display', 'off', 'Algorithm', 'trust-region');
fminconOpts = optimoptions('fmincon','Algorithm','interior-point','OptimalityTolerance',1e-8,'Display','off');

for t = 1:T
    pt = pvec(:,:,t);
    for ns = 1:Ns
        Index = batchSize*(ns-1)+1 : batchSize*ns;
        It = I_vec(:,Index,t);
        if perPeriodInit
            x0 = uvec_init(:,Index,t);
        else
            x0 = uvec_init(:,Index);
        end
        [v_vec(:,Index,t), ~, exitflag] = fsolve(@(x) NHU(x,It,pt,paras), x0, fsolveOpts);
        % Exit flags <= 0 mean failure or that an iteration/evaluation limit
        % was hit without convergence; fall back to minimizing the residual norm.
        if exitflag <= 0
            [v_vec(:,Index,t), ~, exitflag] = fmincon(@(x) norm(NHU(x,It,pt,paras)), x0, [], [], [], [], [], [], [], fminconOpts);
        end
        flag(:,Index,t) = exitflag;
    end
end

flagok = any(flag(:) <= 0);
%%Obtain EV
[u_vec,B_vec]= CalEV(v_vec,pvec,paras);

function err = NHU(v_vec,I_vec,pvec,paras) 
% Absolute residual |E(v,p) - income| for each person in the batch.
if paras.flag_sigma ==1
    paras.sig = variableSigma(v_vec);
end
E = sum(paras.omega.*(((v_vec.^paras.eps_vec).*pvec).^(1-paras.sig))).^(1./(1-paras.sig));
err = abs(E-I_vec);
end

function [u_vec,B_vec]= CalEV(v_vec,pvec,paras)
% u_vec: E(v,p) at base-period prices pvec(:,:,1) (presumably the
% equivalent-variation measure, per the name). B_vec: expenditure shares
% at each period's own prices.
if paras.flag_sigma ==1
    paras.sig = variableSigma(v_vec);
end
pvec_b = pvec(:,:,1);
u_vec = sum(paras.omega.*(((v_vec.^paras.eps_vec).*pvec_b).^(1-paras.sig))).^(1./(1-paras.sig));
tmp = paras.omega.*(((v_vec.^paras.eps_vec).*pvec).^(1-paras.sig));
B_vec = tmp./sum(tmp);
end

function sigv = variableSigma(v)
% Level-dependent elasticity used when paras.flag_sigma == 1: a smooth
% approximation of min(max(10 - 2*log(v), 1.5), 5) (for 10 - 2*log(v) > 0),
% built from a soft-min with 5 and a soft-max with 1.5.
xi = 0.01;
xi2= 100;
sigv = ((1.5)^((xi2-1))+((5^((xi-1)/xi)+(10-2*log(v)).^((xi-1)/xi)).^(xi/(xi-1))).^((xi2-1))).^(1/(xi2-1));
end

end